Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add warning message for beta and gamma parameters #31654

Merged
merged 5 commits into from
Jul 11, 2024

Conversation

OmarManzoor
Copy link
Contributor

What does this PR do?

This adds a warning message to notify about the renaming of gamma and beta parameters during initialisation and also during loading.

Fixes #29554

Before submitting

  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you write any new necessary tests?

Who can review?

@amyeroberts

@OmarManzoor OmarManzoor changed the title Add warning message for and parameters Add warning message for beta and gamma parameters Jun 27, 2024
Comment on lines 3989 to 3993
if "beta" in loaded_keys or "gamma" in loaded_keys:
logger.warning(
f"Parameter names `gamma` or `beta` for {cls.__name__} will be renamed within the model. "
f"Please use different names to suppress this warning."
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't this is quite right, this assumes the weight is called "beta" in the state dict, but it could be called "layer.beta"

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @OmarManzoor,

Thanks for addressing this! We want to make sure we catch any place where the renaming happens, so any place where if gamma in key and if beta in key are True (so key can be a longer string that contains beta or gamma). As you've added, this would be in _load_pretrained_model but also in _load_state_dict_into_model

@OmarManzoor
Copy link
Contributor Author

Hi @OmarManzoor,

Thanks for addressing this! We want to make sure we catch any place where the renaming happens, so any place where if gamma in key and if beta in key are True (so key can be a longer string that contains beta or gamma). As you've added, this would be in _load_pretrained_model but also in _load_state_dict_into_model

Hi @amyeroberts
Thanks for the feedback. Should we remove it during initialization? I added it in post init because during the main init we might not have the parameters declared.

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @OmarManzoor, thanks for iterating on this!

Given the diff, I'm slightly confused, were there no warnings being triggered before? It seems like they were from the tests and logging messages

Comment on lines +1514 to +1515
warning_msg_gamma = "A parameter name that contains `gamma` will be renamed internally"
model = TestModelGamma(config)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More importantly, we should check that the parameter is renamed as well

Copy link
Contributor Author

@OmarManzoor OmarManzoor Jul 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried this out and it seems that the parameter is not renamed at all. Basically when we load the model using from_pretrained it seems that the parameter is still present with the name gamma_param.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't rename the value in the model, but will rename the value in the state_dict, I believe. Could you dive into the loading logic and verify what's happening?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried updating the tests. Could you kindly have a look?

src/transformers/modeling_utils.py Outdated Show resolved Hide resolved
@OmarManzoor
Copy link
Contributor Author

Given the diff, I'm slightly confused, were there no warnings being triggered before? It seems like they were from the tests and logging messages

I basically removed the warning code that I added in the post init method. Should that be kept?

@amyeroberts
Copy link
Collaborator

@OmarManzoor Ah, OK. I think the diff was rendering funny on github. Should be OK.

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great - thanks for adding and iterating on this!

@amyeroberts amyeroberts merged commit 1499a55 into huggingface:main Jul 11, 2024
20 checks passed
@OmarManzoor
Copy link
Contributor Author

Looks great - thanks for adding and iterating on this!

Thank you.

@OmarManzoor OmarManzoor deleted the warning_for_gamma_beta branch July 11, 2024 12:06
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Jul 19, 2024
* Add warning message for  and  parameters

* Fix when the warning is raised

* Formatting changes

* Improve testing and remove duplicated warning from _fix_key
MHRDYN7 pushed a commit to MHRDYN7/transformers that referenced this pull request Jul 23, 2024
* Add warning message for  and  parameters

* Fix when the warning is raised

* Formatting changes

* Improve testing and remove duplicated warning from _fix_key
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Jul 24, 2024
* Add warning message for  and  parameters

* Fix when the warning is raised

* Formatting changes

* Improve testing and remove duplicated warning from _fix_key
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Can't load models with a gamma or beta parameter
2 participants